﻿using System;
#if !NET
using System.Collections.Generic;
#endif
using System.Globalization;
#if !NET
using System.IO;
using System.Threading.Tasks;
#endif
using System.Net;
using System.Net.Sockets;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Threading;

using Renci.SshNet.Messages;

namespace Renci.SshNet.Common
{
    /// <summary>
    /// Collection of different extension methods.
    /// </summary>
    internal static class Extensions
    {
#pragma warning disable S4136 // Method overloads should be grouped together
        /// <summary>
        /// Gets the SSH service identifier for <paramref name="serviceName"/>, encoded with <c>SshData.Ascii</c>.
        /// </summary>
        /// <param name="serviceName">The service name to convert.</param>
        /// <returns>
        /// The encoded bytes of <c>ssh-userauth</c> or <c>ssh-connection</c>.
        /// </returns>
        /// <exception cref="NotSupportedException"><paramref name="serviceName"/> is not one of the handled values.</exception>
        internal static byte[] ToArray(this ServiceName serviceName)
#pragma warning restore S4136 // Method overloads should be grouped together
        {
            var identifier = serviceName switch
            {
                ServiceName.UserAuthentication => "ssh-userauth",
                ServiceName.Connection => "ssh-connection",
                _ => throw new NotSupportedException(string.Format("Service name '{0}' is not supported.", serviceName)),
            };

            return SshData.Ascii.GetBytes(identifier);
        }

        /// <summary>
        /// Decodes <paramref name="data"/> with <c>SshData.Ascii</c> and maps the resulting SSH service
        /// identifier to a <see cref="ServiceName"/>.
        /// </summary>
        /// <param name="data">The encoded service identifier.</param>
        /// <returns>
        /// The <see cref="ServiceName"/> matching <c>ssh-userauth</c> or <c>ssh-connection</c>.
        /// </returns>
        /// <exception cref="NotSupportedException">The decoded identifier is not one of the handled values.</exception>
        internal static ServiceName ToServiceName(this byte[] data)
        {
            return SshData.Ascii.GetString(data, 0, data.Length) switch
            {
                "ssh-userauth" => ServiceName.UserAuthentication,
                "ssh-connection" => ServiceName.Connection,
                var unknown => throw new NotSupportedException(string.Format("Service name '{0}' is not supported.", unknown)),
            };
        }

        /// <summary>
        /// Creates a <see cref="BigInteger"/> from a signed, big-endian byte sequence.
        /// </summary>
        /// <param name="data">The bytes to convert, most significant byte first.</param>
        /// <returns>The <see cref="BigInteger"/> represented by <paramref name="data"/>.</returns>
        internal static BigInteger ToBigInteger(this ReadOnlySpan<byte> data)
        {
#if NET
            return new BigInteger(data, isUnsigned: false, isBigEndian: true);
#else
            // The byte[] constructor expects little-endian input, so copy the bytes in reverse order.
            var littleEndian = new byte[data.Length];
            for (int source = data.Length - 1, target = 0; source >= 0; source--, target++)
            {
                littleEndian[target] = data[source];
            }

            return new BigInteger(littleEndian);
#endif
        }

        /// <summary>
        /// Creates a <see cref="BigInteger"/> from a signed, big-endian byte array.
        /// </summary>
        /// <param name="data">The bytes to convert, most significant byte first. The array is not modified.</param>
        /// <returns>The <see cref="BigInteger"/> represented by <paramref name="data"/>.</returns>
        internal static BigInteger ToBigInteger(this byte[] data)
        {
#if NET
            return new BigInteger(new ReadOnlySpan<byte>(data), isUnsigned: false, isBigEndian: true);
#else
            // Work on a copy so the caller's array keeps its byte order.
            var littleEndian = (byte[])data.Clone();
            Array.Reverse(littleEndian);
            return new BigInteger(littleEndian);
#endif
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="BigInteger"/> structure using the SSH BigNum2 Format.
        /// </summary>
        /// <param name="data">The bytes to convert, read as an unsigned big-endian magnitude. The array is not modified.</param>
        /// <returns>
        /// The non-negative <see cref="BigInteger"/> represented by <paramref name="data"/>; zero when
        /// <paramref name="data"/> is empty.
        /// </returns>
        public static BigInteger ToBigInteger2(this byte[] data)
        {
#if NET
            return new BigInteger(data, isBigEndian: true, isUnsigned: true);
#else
            // The byte[] constructor reads little-endian two's complement. When the most significant
            // bit is set, an extra zero byte keeps the value positive. The length check makes an empty
            // array yield zero instead of failing on data[0].
            var needsSignByte = data.Length > 0 && (data[0] & 0x80) != 0;

            var buffer = new byte[needsSignByte ? data.Length + 1 : data.Length];
            Buffer.BlockCopy(data, 0, buffer, buffer.Length - data.Length, data.Length);
            Array.Reverse(buffer);
            return new BigInteger(buffer);
#endif
        }

#if !NET
        /// <summary>
        /// Converts a <see cref="BigInteger"/> to a byte array, optionally unsigned and/or big-endian.
        /// </summary>
        /// <param name="bigInt">The value to convert.</param>
        /// <param name="isUnsigned">
        /// <see langword="true"/> to drop the trailing zero byte that only carries the sign of a
        /// non-negative value; otherwise, <see langword="false"/>.
        /// </param>
        /// <param name="isBigEndian">
        /// <see langword="true"/> to return the most significant byte first; otherwise, <see langword="false"/>.
        /// </param>
        /// <returns>The byte representation of <paramref name="bigInt"/>; always at least one byte long.</returns>
        public static byte[] ToByteArray(this BigInteger bigInt, bool isUnsigned = false, bool isBigEndian = false)
        {
            // Little-endian two's complement, as produced by the parameterless instance method.
            var data = bigInt.ToByteArray();

            // Only strip the sign byte when something remains: zero is encoded as a single 0x00 byte
            // and must stay that way, matching the framework overload this method stands in for.
            // NOTE(review): the framework overload is documented to throw OverflowException for
            // negative values when isUnsigned is true; this helper does not. Confirm whether callers need that.
            if (isUnsigned && data.Length > 1 && data[data.Length - 1] == 0)
            {
                Array.Resize(ref data, data.Length - 1);
            }

            if (isBigEndian)
            {
                Array.Reverse(data);
            }

            return data;
        }
#endif

#if !NET
        /// <summary>
        /// Gets the number of bits required for the shortest two's complement representation of
        /// <paramref name="bigint"/>, not counting the sign bit.
        /// </summary>
        /// <param name="bigint">The value to measure.</param>
        /// <returns>The bit length; <c>0</c> for both <c>0</c> and <c>-1</c>.</returns>
        public static long GetBitLength(this BigInteger bigint)
        {
            // Computed from the bytes rather than Math.Ceiling(BigInteger.Log(x, 2)): the logarithm is
            // a floating-point approximation and can land just above an exact power of two, which
            // makes the ceiling one bit too large.

            // A negative value needs as many bits as its one's complement (-x - 1): -1 -> 0, -256 -> 8.
            var magnitude = bigint.Sign < 0 ? -(bigint + BigInteger.One) : bigint;
            if (magnitude.IsZero)
            {
                return 0;
            }

            // Little-endian; a non-negative value may end with a 0x00 byte that only carries the sign.
            var bytes = magnitude.ToByteArray();
            var top = bytes.Length - 1;
            if (bytes[top] == 0)
            {
                top--;
            }

            // Whole bytes below the most significant one, plus the significant bits of that byte.
            var bitLength = top * 8L;
            for (int msb = bytes[top]; msb != 0; msb >>= 1)
            {
                bitLength++;
            }

            return bitLength;
        }
#endif

        // See https://github.com/dotnet/runtime/blob/9b57a265c7efd3732b035bade005561a04767128/src/libraries/Common/src/System/Security/Cryptography/KeyBlobHelpers.cs#L51

        /// <summary>
        /// Exports <paramref name="value"/> as an unsigned big-endian byte array that is at least
        /// <paramref name="length"/> bytes long.
        /// </summary>
        /// <param name="value">The key parameter to export.</param>
        /// <param name="length">The minimum size of the returned array.</param>
        /// <returns>
        /// The bytes of <paramref name="value"/>, left-padded with zeros up to <paramref name="length"/>.
        /// A value that needs more than <paramref name="length"/> bytes is returned unpadded and untruncated.
        /// </returns>
        public static byte[] ExportKeyParameter(this BigInteger value, int length)
        {
            var magnitude = value.ToByteArray(isUnsigned: true, isBigEndian: true);

            // Already large enough (or too large): hand it back untouched and let the BCL crypto
            // APIs, which expect exactly-sized arrays, report any size problem.
            if (magnitude.Length >= length)
            {
                return magnitude;
            }

            // Too short: right-align the bytes in a zero-filled array of the requested size.
            var padded = new byte[length];
            Buffer.BlockCopy(magnitude, 0, padded, length - magnitude.Length, magnitude.Length);
            return padded;
        }

        /// <summary>
        /// Signals <paramref name="waitHandle"/> and suppresses the <see cref="ObjectDisposedException"/>
        /// raised when the handle has already been disposed.
        /// </summary>
        /// <param name="waitHandle">The wait handle to signal.</param>
        /// <remarks>
        /// Intended for call sites where signalling and disposing the handle may race with each other.
        /// </remarks>
        public static void SetIgnoringObjectDisposed(this EventWaitHandle waitHandle)
        {
            try
            {
                _ = waitHandle.Set();
            }
            catch (ObjectDisposedException)
            {
                // Deliberately swallowed: the handle was disposed before it could be signalled.
            }
        }

        /// <summary>
        /// Verifies that <paramref name="value"/> does not exceed <see cref="IPEndPoint.MaxPort"/>.
        /// </summary>
        /// <param name="value">The port number to validate.</param>
        /// <param name="argument">
        /// The parameter name reported by the exception; defaults to the expression passed for <paramref name="value"/>.
        /// </param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="value"/> is greater than <see cref="IPEndPoint.MaxPort"/>.</exception>
        internal static void ValidatePort(this uint value, [CallerArgumentExpression(nameof(value))] string argument = null)
        {
            if (value <= IPEndPoint.MaxPort)
            {
                return;
            }

            var message = string.Format(CultureInfo.InvariantCulture, "Specified value cannot be greater than {0}.", IPEndPoint.MaxPort);
            throw new ArgumentOutOfRangeException(argument, message);
        }

        /// <summary>
        /// Verifies that <paramref name="value"/> lies between <see cref="IPEndPoint.MinPort"/> and
        /// <see cref="IPEndPoint.MaxPort"/>, inclusive.
        /// </summary>
        /// <param name="value">The port number to validate.</param>
        /// <param name="argument">
        /// The parameter name reported by the exception; defaults to the expression passed for <paramref name="value"/>.
        /// </param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="value"/> is less than <see cref="IPEndPoint.MinPort"/> or greater than <see cref="IPEndPoint.MaxPort"/>.
        /// </exception>
        internal static void ValidatePort(this int value, [CallerArgumentExpression(nameof(value))] string argument = null)
        {
            if (value >= IPEndPoint.MinPort && value <= IPEndPoint.MaxPort)
            {
                return;
            }

            var message = value < IPEndPoint.MinPort
                ? string.Format(CultureInfo.InvariantCulture, "Specified value cannot be less than {0}.", IPEndPoint.MinPort)
                : string.Format(CultureInfo.InvariantCulture, "Specified value cannot be greater than {0}.", IPEndPoint.MaxPort);

            throw new ArgumentOutOfRangeException(argument, message);
        }

        /// <summary>
        /// Extracts <paramref name="count"/> consecutive bytes from <paramref name="value"/>, starting at
        /// <paramref name="offset"/>.
        /// </summary>
        /// <param name="value">The array to read from.</param>
        /// <param name="offset">The zero-based index in <paramref name="value"/> of the first byte to take.</param>
        /// <param name="count">The number of bytes to take.</param>
        /// <returns>
        /// A <see cref="byte"/> array holding the requested range of <paramref name="value"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
        /// <remarks>
        /// No copy is made when the requested range spans all of <paramref name="value"/>: the input array
        /// itself is returned. A <paramref name="count"/> of zero yields an empty array regardless of
        /// <paramref name="offset"/>.
        /// </remarks>
        public static byte[] Take(this byte[] value, int offset, int count)
        {
            ArgumentNullException.ThrowIfNull(value);

            if (count != 0)
            {
                if (offset == 0 && count == value.Length)
                {
                    return value;
                }

                var segment = new byte[count];
                Buffer.BlockCopy(value, offset, segment, 0, count);
                return segment;
            }

            return Array.Empty<byte>();
        }

        /// <summary>
        /// Extracts the first <paramref name="count"/> bytes of <paramref name="value"/>.
        /// </summary>
        /// <param name="value">The array to read from.</param>
        /// <param name="count">The number of bytes to take from the start of <paramref name="value"/>.</param>
        /// <returns>
        /// A <see cref="byte"/> array holding the leading <paramref name="count"/> bytes of <paramref name="value"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
        /// <remarks>
        /// No copy is made when <paramref name="count"/> equals the length of <paramref name="value"/>: the
        /// input array itself is returned.
        /// </remarks>
        public static byte[] Take(this byte[] value, int count)
        {
            ArgumentNullException.ThrowIfNull(value);

            byte[] result;

            if (count == 0)
            {
                result = Array.Empty<byte>();
            }
            else if (count == value.Length)
            {
                result = value;
            }
            else
            {
                result = new byte[count];
                Buffer.BlockCopy(value, 0, result, 0, count);
            }

            return result;
        }

        /// <summary>
        /// Determines whether two byte arrays have the same length and the same bytes in the same order.
        /// </summary>
        /// <param name="left">The first array to compare.</param>
        /// <param name="right">The second array to compare.</param>
        /// <returns>
        /// <see langword="true"/> when both arrays contain the same byte sequence; otherwise, <see langword="false"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="left"/> or <paramref name="right"/> is <see langword="null"/>.</exception>
        public static bool IsEqualTo(this byte[] left, byte[] right)
        {
            ArgumentNullException.ThrowIfNull(left);
            ArgumentNullException.ThrowIfNull(right);

            ReadOnlySpan<byte> leftBytes = left;
            ReadOnlySpan<byte> rightBytes = right;

            return leftBytes.SequenceEqual(rightBytes);
        }

        /// <summary>
        /// Removes the zero bytes at the start of a byte array.
        /// </summary>
        /// <param name="value">The array to trim.</param>
        /// <returns>
        /// A new array starting at the first non-zero byte of <paramref name="value"/>, or
        /// <paramref name="value"/> itself when it does not start with a zero byte.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
        /// <remarks>
        /// An array that is empty or consists solely of zero bytes is returned unchanged.
        /// </remarks>
        public static byte[] TrimLeadingZeros(this byte[] value)
        {
            ArgumentNullException.ThrowIfNull(value);

            var firstNonZero = 0;
            while (firstNonZero < value.Length && value[firstNonZero] == 0)
            {
                firstNonZero++;
            }

            // Nothing to trim, or no non-zero byte at all: the input is returned as is.
            if (firstNonZero == 0 || firstNonZero == value.Length)
            {
                return value;
            }

            var trimmed = new byte[value.Length - firstNonZero];
            Buffer.BlockCopy(value, firstNonZero, trimmed, 0, trimmed.Length);
            return trimmed;
        }

        /// <summary>
        /// Pads with leading zeros if needed.
        /// </summary>
        /// <param name="data">The data.</param>
        /// <param name="length">The length to pad to.</param>
        /// <returns>
        /// <paramref name="data"/> itself when it already holds at least <paramref name="length"/> bytes;
        /// otherwise a new array of <paramref name="length"/> bytes with <paramref name="data"/> right-aligned
        /// and the leading bytes set to zero.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is <see langword="null"/>.</exception>
        public static byte[] Pad(this byte[] data, int length)
        {
            // Fail with a descriptive exception instead of a NullReferenceException on data.Length.
            ArgumentNullException.ThrowIfNull(data);

            if (length <= data.Length)
            {
                return data;
            }

            var newData = new byte[length];
            Buffer.BlockCopy(data, 0, newData, length - data.Length, data.Length);
            return newData;
        }

        /// <summary>
        /// Joins two byte arrays into one.
        /// </summary>
        /// <param name="first">The bytes that come first; may be <see langword="null"/>.</param>
        /// <param name="second">The bytes to append; may be <see langword="null"/>.</param>
        /// <returns>
        /// A new array holding <paramref name="first"/> followed by <paramref name="second"/>. When
        /// <paramref name="first"/> is <see langword="null"/> or empty, <paramref name="second"/> is returned as is
        /// (even if it is <see langword="null"/>); otherwise, when <paramref name="second"/> is
        /// <see langword="null"/> or empty, <paramref name="first"/> is returned as is.
        /// </returns>
        public static byte[] Concat(this byte[] first, byte[] second)
        {
            if (first is not { Length: > 0 })
            {
                return second;
            }

            if (second is not { Length: > 0 })
            {
                return first;
            }

            var joined = new byte[first.Length + second.Length];
            first.CopyTo(joined, 0);
            second.CopyTo(joined, first.Length);
            return joined;
        }

        /// <summary>
        /// Reports the <see cref="Socket.Connected"/> state of a socket that may be <see langword="null"/>.
        /// </summary>
        /// <param name="socket">The socket to inspect; may be <see langword="null"/>.</param>
        /// <returns>
        /// <see langword="false"/> when <paramref name="socket"/> is <see langword="null"/>; otherwise the value of
        /// <see cref="Socket.Connected"/>.
        /// </returns>
        internal static bool IsConnected(this Socket socket)
        {
            return socket is { Connected: true };
        }

#if !NET
        /// <summary>
        /// Adds <paramref name="key"/> with <paramref name="value"/> to <paramref name="dictionary"/> unless the
        /// key is already present.
        /// </summary>
        /// <param name="dictionary">The dictionary to add to.</param>
        /// <param name="key">The key of the entry to add.</param>
        /// <param name="value">The value of the entry to add.</param>
        /// <returns>
        /// <see langword="true"/> when the entry was added; <see langword="false"/> when the key already existed,
        /// in which case the dictionary is left unchanged.
        /// </returns>
        internal static bool TryAdd<TKey, TValue>(this Dictionary<TKey, TValue> dictionary, TKey key, TValue value)
        {
            if (dictionary.ContainsKey(key))
            {
                return false;
            }

            dictionary.Add(key, value);
            return true;
        }

        /// <summary>
        /// Removes the entry for <paramref name="key"/> from <paramref name="dictionary"/> and hands back its value.
        /// </summary>
        /// <param name="dictionary">The dictionary to remove from.</param>
        /// <param name="key">The key of the entry to remove.</param>
        /// <param name="value">
        /// Receives the removed value, or the default value of <typeparamref name="TValue"/> when the key was not found.
        /// </param>
        /// <returns>
        /// <see langword="true"/> when the key was found and removed; otherwise, <see langword="false"/>.
        /// </returns>
        internal static bool Remove<TKey, TValue>(this Dictionary<TKey, TValue> dictionary, TKey key, out TValue value)
        {
            var found = dictionary.TryGetValue(key, out value);

            if (found)
            {
                _ = dictionary.Remove(key);
            }
            else
            {
                value = default;
            }

            return found;
        }

        /// <summary>
        /// Forms a slice of <paramref name="arraySegment"/> that starts at <paramref name="index"/> and runs to
        /// the end of the segment.
        /// </summary>
        /// <param name="arraySegment">The segment to slice.</param>
        /// <param name="index">The zero-based index, relative to the segment, at which the slice begins.</param>
        /// <returns>A segment over the same array covering the remaining elements.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="index"/> is negative or greater than the number of elements in <paramref name="arraySegment"/>.
        /// </exception>
        internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index)
        {
            // The unsigned comparison rejects negative values too. Without it a negative index would
            // yield a segment that starts before arraySegment does.
            if ((uint)index > (uint)arraySegment.Count)
            {
                throw new ArgumentOutOfRangeException(nameof(index));
            }

            return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, arraySegment.Count - index);
        }

        /// <summary>
        /// Forms a slice of <paramref name="count"/> elements of <paramref name="arraySegment"/>, starting at
        /// <paramref name="index"/>.
        /// </summary>
        /// <param name="arraySegment">The segment to slice.</param>
        /// <param name="index">The zero-based index, relative to the segment, at which the slice begins.</param>
        /// <param name="count">The number of elements in the slice.</param>
        /// <returns>A segment over the same array covering the requested elements.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="index"/> or <paramref name="count"/> is negative, or the requested range does not lie
        /// within <paramref name="arraySegment"/>.
        /// </exception>
        internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index, int count)
        {
            // Validate against the segment, not just the backing array: the ArraySegment constructor
            // alone would accept a range that runs past the end (or before the start) of arraySegment.
            if ((uint)index > (uint)arraySegment.Count)
            {
                throw new ArgumentOutOfRangeException(nameof(index));
            }

            if ((uint)count > (uint)(arraySegment.Count - index))
            {
                throw new ArgumentOutOfRangeException(nameof(count));
            }

            return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, count);
        }

        /// <summary>
        /// Copies the elements of <paramref name="arraySegment"/> into a new array.
        /// </summary>
        /// <param name="arraySegment">The segment to copy.</param>
        /// <returns>
        /// A new array with the elements of <paramref name="arraySegment"/>, or an empty array when the
        /// segment contains no elements.
        /// </returns>
        internal static T[] ToArray<T>(this ArraySegment<T> arraySegment)
        {
            var length = arraySegment.Count;
            if (length == 0)
            {
                return Array.Empty<T>();
            }

            var copy = new T[length];
            Array.Copy(arraySegment.Array, arraySegment.Offset, copy, 0, length);
            return copy;
        }

#pragma warning disable CA1859 // Use concrete types for improved performance
        /// <summary>
        /// Reads exactly <paramref name="count"/> bytes from <paramref name="stream"/> into
        /// <paramref name="buffer"/>, issuing as many reads as necessary.
        /// </summary>
        /// <param name="stream">The stream to read from.</param>
        /// <param name="buffer">The buffer that receives the bytes.</param>
        /// <param name="offset">The index in <paramref name="buffer"/> at which the first byte is stored.</param>
        /// <param name="count">The number of bytes to read.</param>
        /// <exception cref="EndOfStreamException">
        /// The stream reported end of data (a read returned zero bytes) before <paramref name="count"/> bytes were read.
        /// </exception>
        internal static void ReadExactly(this Stream stream, byte[] buffer, int offset, int count)
#pragma warning restore CA1859
        {
            var position = offset;
            var remaining = count;

            while (remaining > 0)
            {
                var chunk = stream.Read(buffer, position, remaining);
                if (chunk == 0)
                {
                    throw new EndOfStreamException();
                }

                position += chunk;
                remaining -= chunk;
            }
        }

        /// <summary>
        /// Returns a task that completes when <paramref name="task"/> completes or when
        /// <paramref name="cancellationToken"/> is cancelled, whichever happens first.
        /// </summary>
        /// <param name="task">The task to wait for.</param>
        /// <param name="cancellationToken">The token that can abandon the wait.</param>
        /// <returns>
        /// <paramref name="task"/> itself when it has already completed or the token cannot be cancelled;
        /// otherwise a task that mirrors the outcome of <paramref name="task"/>, or is cancelled when the
        /// token fires first.
        /// </returns>
        /// <remarks>
        /// Cancelling the token only abandons the wait; nothing here cancels <paramref name="task"/> itself.
        /// </remarks>
        internal static Task<T> WaitAsync<T>(this Task<T> task, CancellationToken cancellationToken)
        {
            var canBeAbandoned = !task.IsCompleted && cancellationToken.CanBeCanceled;

            return canBeAbandoned ? AwaitWithCancellationAsync() : task;

            async Task<T> AwaitWithCancellationAsync()
            {
                var cancellationSource = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);

                // The registration is disposed as soon as either task wins the race.
                using (cancellationToken.Register(() => cancellationSource.TrySetCanceled(cancellationToken), useSynchronizationContext: false))
                {
                    var winner = await Task.WhenAny(task, cancellationSource.Task).ConfigureAwait(false);

                    // Awaiting the winner propagates its result, exception or cancellation.
                    return await winner.ConfigureAwait(false);
                }
            }
        }

        extension(Array)
        {
            /// <summary>
            /// Gets the constant <c>0x7FFFFFC7</c> (2,147,483,591), exposed as <c>Array.MaxLength</c> on
            /// targets where <c>NET</c> is not defined.
            /// </summary>
            internal static int MaxLength => 0x7FFFFFC7;
        }

        extension(Task t)
        {
            /// <summary>
            /// Gets a value indicating whether the task ran to completion, that is, whether its
            /// <see cref="Task.Status"/> is <see cref="TaskStatus.RanToCompletion"/>.
            /// </summary>
            internal bool IsCompletedSuccessfully => t.Status == TaskStatus.RanToCompletion;
        }
#endif
    }
}
